#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#

from dataclasses import dataclass, MISSING
from typing import Dict, Iterable, Tuple, Type

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.modules import (
    AdditiveGaussianWrapper,
    Delta,
    ProbabilisticActor,
    TanhDelta,
)
from torchrl.objectives import DDPGLoss, LossModule, ValueEstimators

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig


class Iddpg(Algorithm):
    """Same as :class:`~benchmarl.algorithms.Maddpg` (from `https://arxiv.org/abs/1706.02275 <https://arxiv.org/abs/1706.02275>`__) but with decentralized critics.

    Args:
        share_param_critic (bool): Whether to share the parameters of the critic within agent groups.
        loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
        delay_value (bool): whether to separate the target value networks from the value networks used for
            data collection.
        use_tanh_mapping (bool): if ``True``, squash the actions output by the policy into the action range,
            otherwise clip them.
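
    Example:
        A minimal sketch of constructing the algorithm through its config
        (the hyperparameter values below are illustrative, not defaults):

        >>> from benchmarl.algorithms import IddpgConfig
        >>> algorithm_config = IddpgConfig(
        ...     share_param_critic=True,
        ...     loss_function="l2",
        ...     delay_value=True,
        ...     use_tanh_mapping=True,
        ... )
        >>> # algorithm_config is then passed to a benchmarl ``Experiment``,
        >>> # which instantiates ``Iddpg`` internally.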

    """

    def __init__(
        self,
        share_param_critic: bool,
        loss_function: str,
        delay_value: bool,
        use_tanh_mapping: bool,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.share_param_critic = share_param_critic
        self.delay_value = delay_value
        self.loss_function = loss_function
        self.use_tanh_mapping = use_tanh_mapping

    #############################
    # Overridden abstract methods
    #############################

    def _get_loss(
        self, group: str, policy_for_loss: TensorDictModule, continuous: bool
    ) -> Tuple[LossModule, bool]:
        if continuous:
            # DDPG loss with this group's decentralised critic as the value network
            loss_module = DDPGLoss(
                actor_network=policy_for_loss,
                value_network=self.get_value_module(group),
                delay_value=self.delay_value,
                loss_function=self.loss_function,
            )
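            # Point the loss at this group's nested tensordict keys.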
            loss_module.set_keys(
                state_action_value=(group, "state_action_value"),
                reward=(group, "reward"),
                priority=(group, "td_error"),
                done=(group, "done"),
                terminated=(group, "terminated"),
            )

            loss_module.make_value_estimator(
                ValueEstimators.TD0, gamma=self.experiment_config.gamma
            )
            return loss_module, True
        else:
            raise NotImplementedError(
                "Iddpg is not compatible with discrete actions yet"
            )

    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
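        # Parameters are grouped by the loss terms produced by DDPGLoss
        # ("loss_actor" and "loss_value") so each set can be optimised separately.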
        return {
            "loss_actor": list(loss.actor_network_params.flatten_keys().values()),
            "loss_value": list(loss.value_network_params.flatten_keys().values()),
        }

    def _get_policy_for_loss(
        self, group: str, model_config: ModelConfig, continuous: bool
    ) -> TensorDictModule:
        if continuous:
            n_agents = len(self.group_map[group])
            logits_shape = list(self.action_spec[group, "action"].shape)
            actor_input_spec = CompositeSpec(
                {group: self.observation_spec[group].clone().to(self.device)}
            )
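            # The deterministic actor maps each agent's observation to an
            # unbounded "param" tensor with the same shape as the action.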
            actor_output_spec = CompositeSpec(
                {
                    group: CompositeSpec(
                        {"param": UnboundedContinuousTensorSpec(shape=logits_shape)},
                        shape=(n_agents,),
                    )
                }
            )
            actor_module = model_config.get_model(
                input_spec=actor_input_spec,
                output_spec=actor_output_spec,
                agent_group=group,
                input_has_agent_dim=True,
                n_agents=n_agents,
                centralised=False,
                share_params=self.experiment_config.share_policy_params,
                device=self.device,
                action_spec=self.action_spec,
            )

            policy = ProbabilisticActor(
                module=actor_module,
                spec=self.action_spec[group, "action"],
                in_keys=[(group, "param")],
                out_keys=[(group, "action")],
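                # TanhDelta squashes the deterministic "param" into the action
                # bounds; plain Delta leaves it unbounded and relies on
                # safe=True spec projection to clip out-of-range actions.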
                distribution_class=TanhDelta if self.use_tanh_mapping else Delta,
                distribution_kwargs={
                    "min": self.action_spec[(group, "action")].space.low,
                    "max": self.action_spec[(group, "action")].space.high,
                }
                if self.use_tanh_mapping
                else {},
                return_log_prob=False,
                safe=not self.use_tanh_mapping,
            )
            return policy
        else:
            raise NotImplementedError(
                "Iddpg is not compatible with discrete actions yet"
            )

    def _get_policy_for_collection(
        self, policy_for_loss: TensorDictModule, group: str, continuous: bool
    ) -> TensorDictModule:
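        # Add annealed Gaussian noise on top of the deterministic policy during
        # collection; sigma is annealed from exploration_eps_init to
        # exploration_eps_end over the configured number of frames.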
        return AdditiveGaussianWrapper(
            policy_for_loss,
            annealing_num_steps=self.experiment_config.get_exploration_anneal_frames(
                self.on_policy
            ),
            action_key=(group, "action"),
            sigma_init=self.experiment_config.exploration_eps_init,
            sigma_end=self.experiment_config.exploration_eps_end,
        )

    def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
        keys = list(batch.keys(True, True))
        group_shape = batch.get(group).shape

        nested_done_key = ("next", group, "done")
        nested_terminated_key = ("next", group, "terminated")
        nested_reward_key = ("next", group, "reward")
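
        # When the environment provides only global done/terminated/reward
        # signals, broadcast them across the agent dimension into the
        # per-agent nested keys expected by the loss.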

        if nested_done_key not in keys:
            batch.set(
                nested_done_key,
                batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
            )
        if nested_terminated_key not in keys:
            batch.set(
                nested_terminated_key,
                batch.get(("next", "terminated"))
                .unsqueeze(-1)
                .expand((*group_shape, 1)),
            )

        if nested_reward_key not in keys:
            batch.set(
                nested_reward_key,
                batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
            )

        return batch

    #####################
    # Custom new methods
    #####################

    def get_value_module(self, group: str) -> TensorDictModule:
        n_agents = len(self.group_map[group])
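        # Decentralised critic: each agent's Q-network sees only its own
        # observation and action (centralised=False) and outputs a per-agent
        # "state_action_value". This is what distinguishes IDDPG from MADDPG.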
        modules = []

        critic_input_spec = CompositeSpec(
            {
                group: self.observation_spec[group]
                .clone()
                .update(self.action_spec[group])
            }
        )
        critic_output_spec = CompositeSpec(
            {
                group: CompositeSpec(
                    {
                        "state_action_value": UnboundedContinuousTensorSpec(
                            shape=(n_agents, 1)
                        )
                    },
                    shape=(n_agents,),
                )
            }
        )

        modules.append(
            self.critic_model_config.get_model(
                input_spec=critic_input_spec,
                output_spec=critic_output_spec,
                n_agents=n_agents,
                centralised=False,
                input_has_agent_dim=True,
                agent_group=group,
                share_params=self.share_param_critic,
                device=self.device,
                action_spec=self.action_spec,
            )
        )

        return TensorDictSequential(*modules)


@dataclass
class IddpgConfig(AlgorithmConfig):
    """Configuration dataclass for :class:`~benchmarl.algorithms.Iddpg`."""

    share_param_critic: bool = MISSING
    loss_function: str = MISSING
    delay_value: bool = MISSING
    use_tanh_mapping: bool = MISSING

    @staticmethod
    def associated_class() -> Type[Algorithm]:
        return Iddpg

    @staticmethod
    def supports_continuous_actions() -> bool:
        return True

    @staticmethod
    def supports_discrete_actions() -> bool:
        return False

    @staticmethod
    def on_policy() -> bool:
        return False
